#include "ToRORd_dynCl_mixed_endo_mid_epi.h"
#include <limits.h>
#include <stdlib.h>

real max_step;
real min_step;
real abstol;
real reltol;
bool adpt;
real *ode_dt, *ode_previous_dt, *ode_time_new;

/*
 * Fill in the model metadata the caller asked for: the number of ODE state
 * variables (when get_neq is set) and the initial membrane potential (when
 * get_initial_v is set).
 */
GET_CELL_MODEL_DATA(init_cell_model_data) {
    if(get_neq) {
        cell_model->number_of_ode_equations = NEQ;
    }

    if(get_initial_v) {
        cell_model->initial_v = INITIAL_V;
    }
}

SET_ODE_INITIAL_CONDITIONS_CPU(set_model_initial_conditions_cpu) {

    log_info("Using ToRORd_dynCl_mixed_endo_mid_epi CPU model\n");

    uint32_t num_cells = solver->original_num_cells;
    solver->sv = (real*)malloc(NEQ*num_cells*sizeof(real));

    max_step = solver->max_dt;
    min_step = solver->min_dt;
    abstol   = solver->abs_tol;
    reltol   = solver->rel_tol;
    adpt     = solver->adaptive;

    if(adpt)
    {
        ode_dt = (real*)malloc(num_cells*sizeof(real));

        OMP(parallel for)
        for(int i = 0; i < num_cells; i++)
            ode_dt[i] = solver->min_dt;

        ode_previous_dt = (real*)calloc(num_cells, sizeof(real));
        ode_time_new    = (real*)calloc(num_cells, sizeof(real));
        log_info("Using Adaptive Euler model to solve the ODEs\n");
    }
    else
    {
        log_info("Using Euler model to solve the ODEs\n");
    }

    real *initial_endo = NULL;
    real *initial_epi = NULL;
    real *initial_mid = NULL;
    real *mapping = NULL;
    if(solver->ode_extra_data)
    {
        initial_endo = (real *)solver->ode_extra_data;
        initial_epi = (real *)solver->ode_extra_data+NEQ;
        initial_mid = (real *)solver->ode_extra_data+NEQ+NEQ;
        mapping = (real *)solver->ode_extra_data+NEQ+NEQ+NEQ;
    }
    else
    {
        log_error_and_exit("You must supply a mask function to tag the cells when using this mixed model!\n");
    }

    OMP(parallel for)
    for(uint32_t i = 0; i < num_cells; i++)
    {
        real *sv = &solver->sv[i * NEQ];

        for (int j = 0; j < NEQ; j++)
        {
            if (mapping[i] == ENDO)
                sv[j] = initial_endo[j];
            else if (mapping[i] == EPI)
                sv[j] = initial_epi[j];
            else
                sv[j] = initial_mid[j];
        }
    }
}

/*
 * Advance every requested cell over one monodomain time step.
 *
 * With adaptive stepping each cell is integrated by
 * solve_forward_euler_cpu_adpt up to current_t + dt; otherwise each cell
 * takes num_steps fixed explicit Euler steps of size dt.
 *
 * When cells_to_solve is non-NULL, the i-th solved cell lives at
 * cells_to_solve[i] in the state array; otherwise it lives at index i.
 * The per-cell type tags are read from ode_extra_data after the three
 * NEQ-sized initial-condition templates.
 */
SOLVE_MODEL_ODES(solve_model_odes_cpu) {

    size_t num_cells_to_solve = ode_solver->num_cells_to_solve;
    uint32_t *cells_to_solve = ode_solver->cells_to_solve;
    real *sv = ode_solver->sv;
    real dt = ode_solver->min_dt;
    uint32_t num_steps = ode_solver->num_steps;

    // Get the mapping array
    real *mapping = NULL;
    if(ode_solver->ode_extra_data) {
        mapping = (real *)ode_solver->ode_extra_data + NEQ + NEQ + NEQ;
    } else {
        log_error_and_exit("You must supply a mask function to tag the cells when using this mixed model!\n");
    }

    // uint32_t (standard) instead of the BSD-only u_int32_t.
    OMP(parallel for)
    for(uint32_t i = 0; i < num_cells_to_solve; i++) {

        uint32_t sv_id = cells_to_solve ? cells_to_solve[i] : i;

        // size_t offset: sv_id * NEQ in 32-bit unsigned arithmetic can wrap.
        real *cell_sv = sv + (size_t)sv_id * NEQ;

        // NOTE(review): stim_currents and mapping are indexed by the
        // solve-list position i, while the state uses sv_id; confirm the
        // mask function lays out the mapping in the same order as stim_currents.
        if(adpt) {
            solve_forward_euler_cpu_adpt(cell_sv, stim_currents[i], mapping[i], current_t + dt, sv_id);
        } else {
            for(uint32_t step = 0; step < num_steps; ++step) {
                solve_model_ode_cpu(dt, cell_sv, stim_currents[i], mapping[i]);
            }
        }
    }
}

/*
 * Advance one cell by a single fixed-size explicit (forward) Euler step.
 *
 * The state is copied into a local snapshot first, so RHS_cpu evaluates the
 * derivatives on the unmodified state while sv is overwritten in place.
 *
 * An Euler + Rush-Larsen variant would instead copy the RHS output directly
 * into sv for indices 12-33 and 41..NEQ-1, keeping the Euler update for the
 * remaining indices.
 */
void solve_model_ode_cpu(real dt, real *sv, real stim_current, real mapping)  {

    assert(sv);

    real state_snapshot[NEQ];
    real derivatives[NEQ];

    for(int k = 0; k < NEQ; ++k) {
        state_snapshot[k] = sv[k];
    }

    RHS_cpu(state_snapshot, derivatives, stim_current, dt, mapping);

    for(int k = 0; k < NEQ; ++k) {
        sv[k] = state_snapshot[k] + dt * derivatives[k];
    }
}

/*
 * Integrate one cell from its current time up to final_time with an adaptive
 * explicit Euler method.
 *
 * Each trial step takes an Euler step with slope k1 and evaluates the slope
 * k2 at the Euler result. It estimates the local error as (dt/2)*(k1 - k2),
 * scaled by max(abstol, reltol*|y|). The step size is then rescaled by
 * 0.8*sqrt(1/error) and clamped to [min_step, max_step]. Steps with
 * error >= 1 are rejected and retried, up to a retry cap. Accepted steps keep
 * the Euler solution.
 *
 * Per-cell integrator state lives in the file-scope arrays ode_dt,
 * ode_previous_dt and ode_time_new at index sv_id.
 *
 * @param sv          State vector of the cell (NEQ entries), updated in place.
 * @param stim_curr   Stimulus current forwarded to RHS_cpu.
 * @param mapping     Cell-type tag forwarded to RHS_cpu.
 * @param final_time  Time up to which the cell is integrated.
 * @param sv_id       Index of the cell in the per-cell integrator arrays.
 */
void solve_forward_euler_cpu_adpt(real *sv, real stim_curr, real mapping, real final_time, int sv_id) {

    const real _beta_safety_ = 0.8;
    int numEDO = NEQ;

    real rDY[numEDO];

    real _tolerances_[numEDO];
    real _aux_tol = 0.0;
    //initializes the variables
    ode_previous_dt[sv_id] = ode_dt[sv_id];

    real edos_old_aux_[numEDO];
    real edos_new_euler_[numEDO];

    // Slope buffers live on the stack like the other work arrays instead of
    // being malloc'd (unchecked) on every call for every cell. Only the
    // pointers are swapped when a step is accepted.
    real k1_storage[numEDO];
    real k2_storage[numEDO];
    real *_k1__ = k1_storage;
    real *_k2__ = k2_storage;
    real *_k_aux__;

    real *dt = &ode_dt[sv_id];
    real *time_new = &ode_time_new[sv_id];
    real *previous_dt = &ode_previous_dt[sv_id];

    // Never step past final_time.
    if(*time_new + *dt > final_time) {
       *dt = final_time - *time_new;
    }

    RHS_cpu(sv, rDY, stim_curr, *dt, mapping);
    *time_new += *dt;

    for(int i = 0; i < numEDO; i++){
        _k1__[i] = rDY[i];
    }

    // Added to the error estimate; keeps it strictly positive when abstol > 0.
    const double tiny = pow(abstol, 2.0);

    int count = 0;

    // Maximum number of rejected trial steps: the number of remaining
    // min_step-sized steps plus 2000000. It is computed in double and
    // saturated, because converting an out-of-range, infinite or NaN quotient
    // to int, or overflowing count_limit + 2000000, is undefined behaviour.
    const double remaining_steps = (final_time - *time_new)/min_step;
    int count_limit;
    if(!(remaining_steps < (double)(INT_MAX - 2000000))) {
        count_limit = INT_MAX;
    } else if(remaining_steps <= -2000000.0) {
        count_limit = 0;
    } else {
        count_limit = (int)remaining_steps + 2000000;
    }

    while(1) {

        for(int i = 0; i < numEDO; i++) {
            //stores the old variables in a vector
            edos_old_aux_[i] = sv[i];
            //computes euler method
            edos_new_euler_[i] = _k1__[i] * *dt + edos_old_aux_[i];
            //steps ahead to compute the rk2 method
            sv[i] = edos_new_euler_[i];
        }

        *time_new += *dt;
        RHS_cpu(sv, rDY, stim_curr, *dt, mapping);
        *time_new -= *dt;//step back

        double greatestError = 0.0, auxError = 0.0;
        for(int i = 0; i < numEDO; i++) {
            //stores the new evaluation
            _k2__[i] = rDY[i];
            _aux_tol = fabs(edos_new_euler_[i])*reltol;
            _tolerances_[i] = (abstol > _aux_tol )?abstol:_aux_tol;
            //finds the greatest error between  the steps
            auxError = fabs(( (*dt/2.0)*(_k1__[i] - _k2__[i])) / _tolerances_[i]);

            greatestError = (auxError > greatestError) ? auxError : greatestError;
        }
        ///adapt the time step
        greatestError += tiny;
        *previous_dt = *dt;
        ///adapt the time step
        *dt = _beta_safety_ * (*dt) * sqrt(1.0f/greatestError);

        if (*time_new + *dt > final_time) {
            *dt = final_time - *time_new;
        }

        //it doesn't accept the solution
        if ( count < count_limit  && (greatestError >= 1.0f)) {
            //restore the old values to do it again
            for(int i = 0;  i < numEDO; i++) {
                sv[i] = edos_old_aux_[i];
            }

            count++;
            //throw the results away and compute again
        } else{//it accepts the solutions


            if(greatestError >=1.0) {
                printf("Accepting solution with error > %lf \n", greatestError);
            }

            if (*dt < min_step) {
                *dt = min_step;
            }

            else if (*dt > max_step && max_step != 0) {
                *dt = max_step;
            }

            if (*time_new + *dt > final_time) {
                *dt = final_time - *time_new;
            }

            // The slope at the accepted point becomes k1 of the next step.
            _k_aux__ = _k2__;
            _k2__   = _k1__;
            _k1__   = _k_aux__;

            //it steps the method ahead, with euler solution
            for(int i = 0; i < numEDO; i++){
                sv[i] = edos_new_euler_[i];
            }

            if(*time_new + *previous_dt >= final_time){
                if((fabs(final_time - *time_new) < 1.0e-5) ){
                    break;
                }else if(*time_new < final_time){
                    *dt = *previous_dt = final_time - *time_new;
                    *time_new += *previous_dt;
                    break;

                }else{
                    printf("Error: time_new %.20lf final_time %.20lf diff %e \n", *time_new , final_time, fabs(final_time - *time_new) );
                    break;
                }
            }else{
                *time_new += *previous_dt;
            }

        }
    }
}

/*
 * Evaluate the right-hand side of the ToRORd dynamic-chloride model for one
 * cell.
 *
 * This function only unpacks the state vector into named locals (sv[0..44]).
 * The model equations come from ToRORd_dynCl_mixed_endo_mid_epi.common.c,
 * which is textually included at the end of the body. That file presumably
 * reads these locals, plus celltype, calc_I_stim and dt, and writes the
 * derivatives into rDY_ (TODO confirm against that file). Because of the
 * textual include, the local names here must not be changed independently of
 * it.
 *
 * @param sv            State vector of the cell (indices 0..44 are read).
 * @param rDY_          Output array for the right-hand side (presumably filled
 *                      by the included equations — confirm).
 * @param stim_current  Stimulus current applied to this cell.
 * @param dt            Time step, made visible to the included equations.
 * @param mapping       Cell-type tag (this file compares it to ENDO / EPI).
 */
void RHS_cpu(const real *sv, real *rDY_, real stim_current, real dt, real mapping) {

    // Get the celltype for the current cell
    real celltype = mapping;

    // Get the stimulus current from the current cell
    real calc_I_stim = stim_current;

    // State variables (names must match those used by the included equations)
    real v = sv[0];
    real CaMKt = sv[1];
    real cass = sv[2];
    real nai = sv[3];
    real nass = sv[4];
    real ki = sv[5];
    real kss = sv[6];
    real cansr = sv[7];
    real cajsr = sv[8];
    real cai = sv[9];
    real cli = sv[10];
    real clss = sv[11];
    real m = sv[12];
    real h = sv[13];
    real j = sv[14];
    real hp = sv[15];
    real jp = sv[16];
    real mL = sv[17];
    real hL = sv[18];
    real hLp = sv[19];
    real a = sv[20];
    real iF = sv[21];
    real iS = sv[22];
    real ap = sv[23];
    real iFp = sv[24];
    real iSp = sv[25];
    real d = sv[26];
    real ff = sv[27];
    real fs = sv[28];
    real fcaf = sv[29];
    real fcas = sv[30];
    real jca = sv[31];
    real ffp = sv[32];
    real fcafp = sv[33];
    real nca_ss = sv[34];
    real nca_i = sv[35];
    real C1 = sv[36];
    real C2 = sv[37];
    real C3 = sv[38];
    real I = sv[39];
    real O = sv[40];
    real xs1 = sv[41];
    real xs2 = sv[42];
    real Jrel_np = sv[43];
    real Jrel_p = sv[44];

    // Model equations shared with other backends of this model (textual include).
    #include "ToRORd_dynCl_mixed_endo_mid_epi.common.c"
}
